{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Building a Speech Recognition System - Decoding and Evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## From Model Output to Recognized Text"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model output is a tensor of shape `(b, len, vocab_size)`: batch size, number of output frames, and vocabulary size."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Greedy search\n",
    "\n",
    "At each time step, pick the token with the highest predicted probability."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
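    "import torch\n",
    "\n",
    "\n",
    "def greedy_search(ctc_probs: torch.Tensor, encoder_out_lens: torch.Tensor):\n",
    "    \"\"\"Return the highest-scoring token ID at every valid frame of each utterance.\"\"\"\n",
    "    # ctc_probs has shape (batch_size, maxlen, vocab_size).\n",
    "    batch_size, maxlen = ctc_probs.size()[:2]\n",
    "    # Top-1 over the vocabulary axis gives the best token ID per frame.\n",
    "    _, topk_index = ctc_probs.topk(1, dim=2)\n",
    "    topk_index = topk_index.view(batch_size, maxlen)\n",
    "    encoder_out_lens = encoder_out_lens.view(-1).tolist()\n",
    "\n",
    "    hyps = []\n",
    "    for i, length in enumerate(encoder_out_lens):\n",
    "        # Keep only the valid (unpadded) frames of utterance i.\n",
    "        hyps.append(topk_index[i, :length].tolist())\n",
    "\n",
    "    return hyps"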
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([46, 51, 44, 44, 41, 49, 48, 48, 74, 93, 44, 49, 50, 51, 58, 50])\n"
     ]
    }
   ],
   "source": [
    "tensordict = torch.load(\"./example2.pt\")\n",
    "\n",
    "pre = tensordict[\"pre\"].to(\"cpu\")\n",
    "lens = tensordict[\"lens\"].to(\"cpu\")\n",
    "\n",
    "print(lens)\n",
    "res = greedy_search(pre, lens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2, 2, 5, 323, 5, 296, 5, 5, 75, 5, 243, 278, 278, 5, 394, 5, 5, 5, 51, 5, 5, 247, 5, 5, 360, 5, 364, 5, 5, 57, 5, 5, 5, 238, 122, 5, 65, 5, 5, 167, 271, 5, 5, 142, 5, 68, 5, 5, 5, 3, 3]\n",
      "['<sos>', '<sos>', '<blk>', 'liao', '<blk>', 'dian', '<blk>', '<blk>', 'pao', '<blk>', 'si', 'ji', 'ji', '<blk>', 'kong', '<blk>', '<blk>', '<blk>', 'cuan', '<blk>', '<blk>', 'hen', '<blk>', '<blk>', 'zong', '<blk>', 'shua', '<blk>', '<blk>', 'qie', '<blk>', '<blk>', '<blk>', 'nian', 'tian', '<blk>', 'wai', '<blk>', '<blk>', 'zhou', 'du', '<blk>', '<blk>', 'jiao', '<blk>', 'die', '<blk>', '<blk>', '<blk>', '<eos>', '<eos>']\n"
     ]
    }
   ],
   "source": [
    "from tokenizer.tokenizer import Tokenizer\n",
    "tokenizer = Tokenizer(\"./tokenizer/vocab.txt\")\n",
    "\n",
    "print(res[1])\n",
    "\n",
    "print(tokenizer.decode(res[1], ignore_special=False))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
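    "**TODO:** Following the CTC decoding rule, decode the model output by collapsing consecutive repeated tokens and then removing blanks (the greedy-search output above still contains both)."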
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
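    "def ctc_decode(hyps, blank_id):\n",
    "    \"\"\"\n",
    "    CTC decoding: collapse consecutive repeats, then drop blanks.\n",
    "    \n",
    "    Args:\n",
    "        hyps: list of frame-level predicted token-ID sequences\n",
    "        blank_id: ID of the blank token\n",
    "        \n",
    "    Returns:\n",
    "        list of decoded token-ID sequences\n",
    "    \"\"\"\n",
    "    decoded_hyps = []\n",
    "    \n",
    "    for hyp in hyps:\n",
    "        decoded = []\n",
    "        prev = None\n",
    "        \n",
    "        for token_id in hyp:\n",
    "            # Emit a token only if it is not blank and differs from the previous frame.\n",
    "            if token_id != blank_id and token_id != prev:\n",
    "                decoded.append(token_id)\n",
    "            \n",
    "            # Update prev on blanks as well, so that a token repeated across\n",
    "            # a blank (e.g. a, <blk>, a) is emitted twice.\n",
    "            prev = token_id\n",
    "        \n",
    "        decoded_hyps.append(decoded)\n",
    "    \n",
    "    return decoded_hyps"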
   ]
  },
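  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A corner case worth checking by hand: CTC collapses only *consecutive* repeats, so a token that appears on both sides of a blank must be emitted twice. The sketch below is a minimal, standalone illustration using `itertools.groupby`; the helper name `ctc_collapse` and the toy IDs (with `5` standing in for the blank) are assumptions made for this example, not part of the course code.\n",
    "\n",
    "```python\n",
    "from itertools import groupby\n",
    "\n",
    "\n",
    "def ctc_collapse(frame_ids, blank_id):\n",
    "    # groupby merges each run of identical consecutive IDs into one group;\n",
    "    # blanks are dropped only after the runs have been merged.\n",
    "    return [tok for tok, _ in groupby(frame_ids) if tok != blank_id]\n",
    "\n",
    "\n",
    "# Hypothetical frame-level IDs, with 5 playing the role of the blank.\n",
    "frames = [5, 7, 7, 5, 7, 5, 5, 9]\n",
    "collapsed = ctc_collapse(frames, blank_id=5)\n",
    "print(collapsed)  # [7, 7, 9]\n",
    "```\n",
    "\n",
    "The two runs of `7` are separated by a blank, so both survive, while the adjacent `7, 7` pair is merged into one."
   ]
  },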
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Raw prediction:\n",
      "[2, 2, 5, 323, 5, 296, 5, 5, 75, 5, 243, 278, 278, 5, 394, 5, 5, 5, 51, 5, 5, 247, 5, 5, 360, 5, 364, 5, 5, 57, 5, 5, 5, 238, 122, 5, 65, 5, 5, 167, 271, 5, 5, 142, 5, 68, 5, 5, 5, 3, 3]\n",
      "['<sos>', '<sos>', '<blk>', 'liao', '<blk>', 'dian', '<blk>', '<blk>', 'pao', '<blk>', 'si', 'ji', 'ji', '<blk>', 'kong', '<blk>', '<blk>', '<blk>', 'cuan', '<blk>', '<blk>', 'hen', '<blk>', '<blk>', 'zong', '<blk>', 'shua', '<blk>', '<blk>', 'qie', '<blk>', '<blk>', '<blk>', 'nian', 'tian', '<blk>', 'wai', '<blk>', '<blk>', 'zhou', 'du', '<blk>', '<blk>', 'jiao', '<blk>', 'die', '<blk>', '<blk>', '<blk>', '<eos>', '<eos>']\n",
      "\n",
      "After CTC decoding:\n",
      "[2, 323, 296, 75, 243, 278, 394, 51, 247, 360, 364, 57, 238, 122, 65, 167, 271, 142, 68, 3]\n",
      "['<sos>', 'liao', 'dian', 'pao', 'si', 'ji', 'kong', 'cuan', 'hen', 'zong', 'shua', 'qie', 'nian', 'tian', 'wai', 'zhou', 'du', 'jiao', 'die', '<eos>']\n"
     ]
    }
   ],
   "source": [
    "blank_id = tokenizer.blk_id()\n",
    "decoded_res = ctc_decode(res, blank_id)\n",
    "\n",
    "print(\"Raw prediction:\")\n",
    "print(res[1])\n",
    "print(tokenizer.decode(res[1], ignore_special=False))\n",
    "\n",
    "print(\"\\nAfter CTC decoding:\")\n",
    "print(decoded_res[1])\n",
    "print(tokenizer.decode(decoded_res[1], ignore_special=False))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluating the Recognition Results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
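    "We evaluate the ASR system with the character error rate (CER), computed here over pinyin tokens:\n",
    "\n",
    "$$\\mathrm{CER} = \\frac{S + D + I}{N}$$\n",
    "\n",
    "Here `pre` is the model prediction and `gt` is the ground-truth transcript. S, D, and I come from the minimum edit distance alignment between the two: S is the number of `gt` tokens replaced by a different token in `pre`, D is the number of `gt` tokens missing from `pre` (deletions), I is the number of extra tokens in `pre` that have no counterpart in `gt` (insertions), and N is the length of `gt`.\n"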
   ]
  },
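  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick worked example, take one of the sample pairs printed by the evaluation cell at the end of this notebook: `gt` is `xi huan ba li ao de shu cha zai niu zai ku de qian mian` (15 tokens) and `pre` is the same sequence without `ao`. Assuming the minimum edit distance alignment, nothing is substituted or inserted and the missing `ao` is the only error, which the `edit_distance` function below counts as a deletion:\n",
    "\n",
    "$$S = 0,\\quad D = 1,\\quad I = 0,\\quad N = 15 \\quad\\Rightarrow\\quad \\mathrm{CER} = \\frac{0 + 1 + 0}{15} \\approx 0.067$$"
   ]
  },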
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**TODO:** Use the minimum edit distance to obtain S, D, I, and N, and compute the CER of the ASR system."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from data.dataloader import get_dataloader\n",
    "from model.model import CTCModel\n",
    "import torch\n",
    "from utils.utils import to_device\n",
    "from tqdm import tqdm\n",
    "\n",
    "dev_dataloader = get_dataloader(\"./dataset/split/dev/wav.scp\", \"./dataset/split/dev/pinyin\", 32, tokenizer, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def edit_distance(ref, hyp):\n",
    "    \"\"\"\n",
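    "    Compute the edit distance between two sequences and return the numbers of substitutions, deletions, and insertions.\n",
    "    \n",
    "    Args:\n",
    "        ref: reference sequence (ground-truth text)\n",
    "        hyp: hypothesis sequence (recognition result)\n",
    "        \n",
    "    Returns:\n",
    "        distance, number of substitutions, number of deletions, number of insertions\n",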
    "    \"\"\"\n",
    "    n = len(ref)\n",
    "    m = len(hyp)\n",
    "    dp = [[0 for _ in range(m+1)] for _ in range(n+1)]\n",
    "    \n",
    "    for i in range(n+1):\n",
    "        dp[i][0] = i\n",
    "    for j in range(m+1):\n",
    "        dp[0][j] = j\n",
    "\n",
    "    for i in range(1, n+1):\n",
    "        for j in range(1, m+1):\n",
    "            if ref[i-1] == hyp[j-1]:\n",
    "                dp[i][j] = dp[i-1][j-1]\n",
    "            else:\n",
    "                dp[i][j] = min(dp[i-1][j-1] + 1,   # substitution\n",
    "                               dp[i-1][j] + 1,     # deletion (ref token missing from hyp)\n",
    "                               dp[i][j-1] + 1)     # insertion (extra token in hyp)\n",
    "\n",
    "    i, j = n, m\n",
    "    s_count = d_count = i_count = 0\n",
    "    \n",
    "    while i > 0 or j > 0:\n",
    "        if i > 0 and j > 0 and ref[i-1] == hyp[j-1]:\n",
    "            i -= 1\n",
    "            j -= 1\n",
    "        elif i > 0 and j > 0 and dp[i][j] == dp[i-1][j-1] + 1:\n",
    "            # substitution\n",
    "            s_count += 1\n",
    "            i -= 1\n",
    "            j -= 1\n",
    "        elif i > 0 and dp[i][j] == dp[i-1][j] + 1:\n",
    "            # deletion\n",
    "            d_count += 1\n",
    "            i -= 1\n",
    "        else:\n",
    "            # insertion\n",
    "            i_count += 1\n",
    "            j -= 1\n",
    "    \n",
    "    return dp[n][m], s_count, d_count, i_count\n",
    "\n",
    "def calculate_cer(references, hypotheses):\n",
    "    \"\"\"\n",
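    "    Compute the character error rate (CER).\n",
    "    \n",
    "    Args:\n",
    "        references: list of reference sequences\n",
    "        hypotheses: list of recognition results\n",
    "        \n",
    "    Returns:\n",
    "        CER, total substitutions, deletions, and insertions, and the total reference length\n",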
    "    \"\"\"\n",
    "    total_distance = 0\n",
    "    total_subs = 0\n",
    "    total_dels = 0\n",
    "    total_ins = 0\n",
    "    total_ref_length = 0\n",
    "    \n",
    "    for ref, hyp in zip(references, hypotheses):\n",
    "        distance, subs, dels, ins = edit_distance(ref, hyp)\n",
    "        total_distance += distance\n",
    "        total_subs += subs\n",
    "        total_dels += dels\n",
    "        total_ins += ins\n",
    "        total_ref_length += len(ref)\n",
    "\n",
    "    cer = (total_subs + total_dels + total_ins) / total_ref_length if total_ref_length > 0 else 1.0\n",
    "    \n",
    "    return cer, total_subs, total_dels, total_ins, total_ref_length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluating: 100%|██████████| 32/32 [00:02<00:00, 12.13it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation results:\n",
      "Substitutions (S): 1672, Deletions (D): 445, Insertions (I): 76, Reference length (N): 19252\n",
      "CER: 0.1139 (2193/19252)\n",
      "\n",
      "Sample comparison:\n",
      "Reference: ['yi', 'ge', 'nan', 'ren', 'tui', 'ran', 'de', 'zuo', 'zai', 'pang', 'bian', 'mu', 'guang', 'dai', 'zhi']\n",
      "Prediction: ['yi', 'gen', 'nan', 'ren', 'tui', 'ran', 'de', 'zuo', 'zai', 'pang', 'bian', 'mu', 'guan', 'dai', 'zhi']\n",
      "\n",
      "Reference: ['xi', 'huan', 'ba', 'li', 'ao', 'de', 'shu', 'cha', 'zai', 'niu', 'zai', 'ku', 'de', 'qian', 'mian']\n",
      "Prediction: ['xi', 'huan', 'ba', 'li', 'de', 'shu', 'cha', 'zai', 'niu', 'zai', 'ku', 'de', 'qian', 'mian']\n",
      "\n",
      "Reference: ['zha', 'yan', 'yi', 'kan', 'xiang', 'qi', 'de', 'shi', 'guang', 'zhou', 'de', 'qu', 'hao', 'ling', 'e', 'er', 'ling']\n",
      "Prediction: ['zhan', 'yan', 'yi', 'kan', 'xiang', 'qi', 'de', 'shi', 'guang', 'zhou', 'de', 'xu', 'hao', 'liu', 'e', 'er', 'ling']\n",
      "\n",
      "Reference: ['ci', 'qian', 'qing', 'hua', 'zi', 'guang', 'jiu', 'ceng', 'zao', 'yu', 'guo', 'tong', 'yang', 'de', 'wei', 'ji']\n",
      "Prediction: ['ci', 'qian', 'qing', 'hua', 'zi', 'guang', 'jiu', 'cun', 'zao', 'yu', 'guo', 'tong', 'yao', 'de', 'wei', 'ji']\n",
      "\n",
      "Reference: ['bu', 'tai', 'ming', 'bai', 'shen', 'me', 'yi', 'si']\n",
      "Prediction: ['bu', 'tai', 'ming', 'bai', 'shen', 'men', 'yi', 'si']\n",
      "\n"
     ]
    }
   ],
   "source": [
    "def evaluate_model(dataloader, model, tokenizer, device='cpu'):\n",
    "    model.eval()\n",
    "    all_refs = []\n",
    "    all_hyps = []\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm(dataloader, desc=\"Evaluating\"):\n",
    "            batch = to_device(batch, device)\n",
    "            audios = batch['audios']\n",
    "            audio_lens = batch['audio_lens']\n",
    "            texts = batch['texts']\n",
    "            text_lens = batch['text_lens']\n",
    "            \n",
    "            encoder_out, _, encoder_out_lens = model(audios, audio_lens, texts, text_lens)\n",
    "            hyps = greedy_search(encoder_out, encoder_out_lens)\n",
    "            decoded_hyps = ctc_decode(hyps, tokenizer.blk_id())\n",
    "            \n",
    "            for i in range(len(text_lens)):\n",
    "                ref = texts[i, :text_lens[i]].tolist()\n",
    "                all_refs.append(ref)\n",
    "                all_hyps.append(decoded_hyps[i])\n",
    "    \n",
    "    # Compute CER\n",
    "    cer, subs, dels, ins, ref_len = calculate_cer(all_refs, all_hyps)\n",
    "    \n",
    "    print(\"Evaluation results:\")\n",
    "    print(f\"Substitutions (S): {subs}, Deletions (D): {dels}, Insertions (I): {ins}, Reference length (N): {ref_len}\")\n",
    "    print(f\"CER: {cer:.4f} ({subs+dels+ins}/{ref_len})\")\n",
    "    \n",
    "    print(\"\\nSample comparison:\")\n",
    "    for i in range(min(5, len(all_refs))):\n",
    "        print(f\"Reference: {tokenizer.decode(all_refs[i])}\")\n",
    "        print(f\"Prediction: {tokenizer.decode(all_hyps[i])}\")\n",
    "        print()\n",
    "    \n",
    "    return cer\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "model = CTCModel(80, 256, tokenizer.size(), tokenizer.blk_id()).to(device)\n",
    "\n",
    "checkpoint = torch.load(\"./model.pt\", map_location=device)\n",
    "model.load_state_dict(checkpoint['model'])\n",
    "\n",
    "cer = evaluate_model(dev_dataloader, model, tokenizer, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
